{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Kp4KS9uLT0LI"
      },
      "source": [
        "Copyright 2022 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cshpfsWsxFcs"
      },
      "source": [
        "## Synthetic data generation\n",
        "\n",
        "This file generates synthetic data using the generating functions provided in `data/lsa_synthetic.py`. The output is written to `colab/tmp_data` by default."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IMDdsYh6NEyi"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "import sklearn\n",
        "import jax\n",
        "import scipy\n",
        "from sklearn.preprocessing import OneHotEncoder\n",
        "from sklearn.metrics import roc_auc_score\n",
        "from sklearn.metrics import accuracy_score\n",
        "from sklearn.metrics import log_loss\n",
        "import re\n",
        "import itertools\n",
        "import os\n",
        "from IPython.display import display\n",
        "from latent_shift_adaptation.data.lsa_synthetic import Simulator, MultiWSimulator"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "i2Fd4jKSNfM8"
      },
      "outputs": [],
      "source": [
        "#@title Library functions\n",
        "\n",
        "def get_squeezed_df(data_dict: dict) -\u003e pd.DataFrame:\n",
        "  \"\"\"Converts a dict of numpy arrays into a DataFrame, extracting columns of arrays into separate DataFrame columns.\"\"\"\n",
        "  temp = {}\n",
        "  for key, value in data_dict.items():\n",
        "    squeezed_array = np.squeeze(value)\n",
        "    if len(squeezed_array.shape) == 1:\n",
        "      temp[key] = squeezed_array\n",
        "    elif len(squeezed_array.shape) \u003e 1:\n",
        "      for i in range(value.shape[1]):\n",
        "        temp[f'{key}_{i}'] = np.squeeze(value[:, i])\n",
        "  return pd.DataFrame(temp)\n",
        "\n",
        "def process_data(data_dict, w_cols=['w_1', 'w_2', 'w_3']):\n",
        "  result = data_dict.copy()\n",
        "  for w_col in w_cols:\n",
        "    result[f'{w_col}_binary'] = 1.0*(result[f'{w_col}'] \u003e 0)\n",
        "    result[f'{w_col}_one_hot'] = OneHotEncoder(sparse=False).fit_transform(result[f'{w_col}_binary'])\n",
        "  result['u_one_hot'] = OneHotEncoder(sparse=False).fit_transform(result['u'].reshape(-1, 1))\n",
        "  return result\n",
        "\n",
        "def generate_data(p_u, seed, num_samples, partition_dict, param_dict=None):\n",
        "\n",
        "  sim = MultiWSimulator(param_dict=param_dict)\n",
        "  samples_dict = {}\n",
        "  for i, (partition_key, partition_frac) in enumerate(partition_dict.items()):\n",
        "    num_samples_partition = int(partition_frac*num_samples)\n",
        "    sim.update_param_dict(num_samples=num_samples_partition, p_u=p_u)\n",
        "    samples_dict[partition_key] = process_data(sim.get_samples(seed=seed + 15*i))\n",
        "  return samples_dict\n",
        "\n",
        "def tidy_w(data_dict, w_value):\n",
        "  result = data_dict.copy()\n",
        "  for key in result.keys():\n",
        "    result[key]['w'] = result[key][f'w_{w_value}']\n",
        "    result[key]['w_binary'] = result[key][f'w_{w_value}_binary']\n",
        "    result[key]['w_one_hot'] = result[key][f'w_{w_value}_one_hot']\n",
        "  return result\n",
        "\n",
        "\n",
        "# Convert data to dataframe format\n",
        "def pack_to_df(samples_dict):\n",
        "  return pd.concat({key: get_squeezed_df(value) for key, value in samples_dict.items()}).reset_index(level=-1, drop=True).rename_axis('partition').reset_index()\n",
        "\n",
        "# Extract dataframe format back to dict format\n",
        "def extract_from_df(samples_df, cols=['u', 'x', 'w', 'c', 'c_logits', 'y', 'y_logits', 'y_one_hot', \n",
        "                                      'u_one_hot', 'x_scaled',\n",
        "                                      'w_1', 'w_1_binary', 'w_1_one_hot',\n",
        "                                      'w_2', 'w_2_binary', 'w_2_one_hot',\n",
        "                                      'w_2_binary', 'w_2_one_hot',\n",
        "                                      ]):\n",
        "  \"\"\"\n",
        "  Extracts dict of numpy arrays from dataframe\n",
        "  \"\"\"\n",
        "  result = {}\n",
        "  for col in cols:\n",
        "    if col in samples_df.columns:\n",
        "      result[col] = samples_df[col].values\n",
        "    else:\n",
        "      match_str = f\"^{col}_\\d$\"\n",
        "      r = re.compile(match_str, re.IGNORECASE)\n",
        "      matching_columns = list(filter(r.match, samples_df.columns))\n",
        "      if len(matching_columns) == 0:\n",
        "        continue\n",
        "      result[col] = samples_df[matching_columns].to_numpy()\n",
        "  return result\n",
        "\n",
        "def extract_from_df_nested(samples_df, cols=['u', 'x', 'w', 'c', 'c_logits', 'y', 'y_logits', 'y_one_hot', 'w_binary', 'w_one_hot', 'u_one_hot', 'x_scaled']):\n",
        "  \"\"\"\n",
        "  Extracts nested dict of numpy arrays from dataframe with structure {domain: {partition: data}}\n",
        "  \"\"\"\n",
        "  result = {}\n",
        "  for domain in samples_df['domain'].unique():\n",
        "    result[domain] = {}\n",
        "    domain_df = samples_df.query('domain == @domain')\n",
        "    for partition in domain_df['partition'].unique():\n",
        "      partition_df = domain_df.query('partition == @partition')\n",
        "      result[domain][partition] = extract_from_df(partition_df, cols=cols)\n",
        "  return result\n",
        "\n",
        "def write_df_to_drive(filename, data, folder_id, overwrite=True):\n",
        "  file_id = drive.SaveFile(filename=filename, data=data.to_csv(index=False), overwrite_warning=1-overwrite)\n",
        "  existing_files = drive.ListFolderWithFileNames(folder_id=folder_id)\n",
        "  existing_files = {key: value for key, value in existing_files.items() if value == filename}\n",
        "  if filename in existing_files.values():\n",
        "    if not overwrite:\n",
        "      print('File exists in folder. Doing nothing')\n",
        "    else:\n",
        "      print('Overwriting file in folder')\n",
        "      for trash_file_id in existing_files.keys():\n",
        "        drive.TrashFile(trash_file_id)\n",
        "      drive.MoveFileToFolder(file_id=file_id, folder_id=folder_id)\n",
        "  else:\n",
        "    drive.MoveFileToFolder(file_id=file_id, folder_id=folder_id)\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fjyHcFNmn3Rh"
      },
      "outputs": [],
      "source": [
        "folder_id = './tmp_data'\n",
        "os.makedirs(folder_id, exist_ok=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lhtXtfXxPYMY"
      },
      "outputs": [],
      "source": [
        "if folder_id is not None:\n",
        "  p_u_range = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]\n",
        "  num_samples = 10000\n",
        "  w_coeff_list = [1, 2, 3]\n",
        "  partition_dict = {'train': 0.7, 'val': 0.1, 'test': 0.2}\n",
        "  seed = 192\n",
        "  result = {}\n",
        "  for i, p_u_0 in enumerate(p_u_range):\n",
        "    p_u = [p_u_0, 1-p_u_0]\n",
        "    samples_dict = generate_data(p_u=p_u, seed=i*seed+i, num_samples=num_samples, partition_dict=partition_dict)\n",
        "    for w_coeff in w_coeff_list:\n",
        "      print(p_u_0, w_coeff)\n",
        "      samples_dict_tidy = tidy_w(samples_dict, w_value=w_coeff)\n",
        "      samples_df = pack_to_df(samples_dict_tidy)\n",
        "      filename = f'synthetic_multivariate_num_samples_10000_w_coeff_{w_coeff}_p_u_0_{p_u_0}.csv'\n",
        "      samples_df.to_csv(os.path.join(folder_id, filename), index=False)\n",
        "    \n",
        "else:\n",
        "  print('folder_id not set. Doing nothing')"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
